{% extends 'base.exported.class' %}

{% block content %}
import com.google.gson.Gson;
import com.google.gson.reflect.TypeToken;

import java.io.File;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.lang.reflect.Type;
import java.util.Arrays;
import java.util.List;
import java.util.Scanner;


class {{ class_name }} {

    private class Tree {
        private int[] lefts;
        private int[] rights;
        private double[] thresholds;
        private int[] indices;
        private double[][] classes;

        private double[] normVals(double[] nums) {
            int i = 0;
            int il = nums.length;
            double[] result = new double[nums.length];
            double sum = 0.;
            for (i = 0; i < il; i++) {
                sum += nums[i];
            }
            if(sum == 0) {
                for (i = 0; i < il; i++) {
                    result[i] = 1.0 / il;
                }
            } else {
                for (i = 0; i < il; i++) {
                    result[i] = nums[i] / sum;
                }
            }
            return result;
        };

        private double[] compute(double[] features, int node) {
            while (this.thresholds[node] != -2) {
                if (features[this.indices[node]] <= this.thresholds[node]) {
                    node = this.lefts[node];
                } else {
                    node = this.rights[node];
                }
            }
            return this.normVals(this.classes[node]);
        }

        private double[] compute(double[] features) {
            return this.compute(features, 0);
        }
    }

    // Trees deserialized from the model file by the constructor.
    private List<Tree> forest;
    // Cached forest.size(), assigned in the constructor.
    private int nTrees;
    // Number of classes: the length of the first tree's first `classes` row.
    private int nClasses;

    /**
     * Loads the forest from a JSON model file.
     *
     * @param file path of the JSON file that holds the list of trees
     * @throws FileNotFoundException if {@code file} cannot be opened for reading
     * @throws IllegalArgumentException if the file yields no trees
     */
    public {{ class_name }} (String file) throws FileNotFoundException {
        String jsonStr;
        // try-with-resources releases the file handle even when reading fails.
        try (Scanner scanner = new Scanner(new File(file))) {
            // "\\Z" only matches at the end of input, so one token is the whole file.
            scanner.useDelimiter("\\Z");
            // An empty file has no token at all; treat it as empty JSON text.
            jsonStr = scanner.hasNext() ? scanner.next() : "";
        }
        Gson gson = new Gson();
        Type listType = new TypeToken<List<Tree>>(){}.getType();
        this.forest = gson.fromJson(jsonStr, listType);
        // Fail with a clear message rather than an index or null error further down.
        if (this.forest == null || this.forest.isEmpty()) {
            throw new IllegalArgumentException("Model file contains no trees: " + file);
        }
        this.nTrees = this.forest.size();
        this.nClasses = this.forest.get(0).classes[0].length;
    }

    /**
     * Finds the position of the largest entry of {@code nums}.
     * On ties the earliest position wins; an empty array yields 0.
     *
     * @param nums values to scan
     * @return index of the first maximal entry
     */
    private static int findMax(double[] nums) {
        int best = 0;
        // Position 0 is the starting candidate, so scanning can begin at 1.
        for (int pos = 1; pos < nums.length; pos++) {
            if (nums[pos] > nums[best]) {
                best = pos;
            }
        }
        return best;
    }

    /**
     * Averages the per-tree class distributions for one sample.
     *
     * @param features the sample to classify
     * @return one averaged value per class
     */
    private double[] compute(double[] features) {
        final double[] mean = new double[this.nClasses];
        // First accumulate every tree's vote per class ...
        for (int t = 0; t < this.nTrees; t++) {
            final double[] vote = this.forest.get(t).compute(features);
            int c = 0;
            while (c < this.nClasses) {
                mean[c] += vote[c];
                c++;
            }
        }
        // ... then turn the sums into averages.
        for (int c = 0; c < this.nClasses; c++) {
            mean[c] = mean[c] / this.nTrees;
        }
        return mean;
    }

    /**
     * Returns the class probabilities the forest assigns to {@code features}.
     *
     * @param features the sample to classify
     * @return one probability per class
     */
    public double[] predictProba(double[] features) {
        final double[] probabilities = this.compute(features);
        return probabilities;
    }

    /**
     * Returns the index of the class with the highest predicted probability.
     *
     * @param features the sample to classify
     * @return index of the winning class
     */
    public int predict(double[] features) {
        final double[] probabilities = this.predictProba(features);
        return findMax(probabilities);
    }

    /**
     * Command-line entry point. The first argument is the path of the JSON model
     * file; every further argument is parsed as one feature value.
     *
     * @param args model path followed by the feature values
     * @throws FileNotFoundException if the model file cannot be opened
     * @throws IllegalArgumentException if no model path is given
     */
    public static void main(String[] args) throws FileNotFoundException {

        // Without a model path there is nothing to load; say so explicitly instead
        // of failing on a negative array size below.
        if (args.length < 1) {
            throw new IllegalArgumentException("Usage: <model.json> [feature ...]");
        }

        // Features:
        double[] features = new double[args.length - 1];
        for (int i = 1, l = args.length; i < l; i++) {
            features[i - 1] = Double.parseDouble(args[i]);
        }

        // Model data:
        String modelData = args[0];

        // Estimator:
        {{ class_name }} clf = new {{ class_name }}(modelData);

        {% if is_test or to_json %}
        // Get JSON:
        int prediction = clf.predict(features);
        double[] probabilities = clf.predictProba(features);
        // Arrays.toString already produces the bracketed, comma-separated list.
        System.out.println("{\"predict\": " + prediction + ", \"predict_proba\": " + Arrays.toString(probabilities) + "}");
        {% else %}
        // Get class prediction:
        int prediction = clf.predict(features);
        System.out.println("Predicted class: #" + prediction);

        // Get class probabilities:
        double[] probabilities = clf.predictProba(features);
        for (int i = 0; i < probabilities.length; i++) {
            System.out.print(String.valueOf(probabilities[i]));
            if (i != probabilities.length - 1) {
                System.out.print(",");
            }
        }
        {% endif %}

    }
}
{% endblock %}